{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 处理数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "datasets_path = \"/root/datasets\"\n",
    "# Load the dataset from the local path above.\n",
    "# NOTE(review): the evaluation section loads from \"/root/datasets/imdb\" instead -- confirm both refer to the same data.\n",
    "# Quick-debug alternative (first 20 rows of each split): datasets = load_dataset(datasets_path,split={\"train\":\"train[:20]\",\"test\":\"test[:20]\"})\n",
    "datasets = load_dataset(datasets_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bare last expression: renders the repr of the loaded dataset object.\n",
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# pretrained_model_path = \"D:/hugging_face/models/bert-base-cased\"\n",
    "pretrained_model_path = \"/root/models/distilbert-base-uncased\" #TODO\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True,max_length= 512) # max_length 必须指定值\n",
    "    \n",
    "\n",
    "tokenized_datasets = datasets.map(tokenize_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 移除 text 列, 因为模型不接受原始文本作为输入:\n",
    "tokenized_datasets = tokenized_datasets.remove_columns([\"text\"])\n",
    "# 将 label 列重命名为 labels, 因为模型期望参数的名称为 labels:\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the dataset to return PyTorch tensors.\n",
    "# The return value is not reassigned, so this relies on set_format updating the object in place.\n",
    "tokenized_datasets.set_format(\"torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: inspect the ``labels`` column of the training split.\n",
    "tokenized_datasets['train']['labels']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_dataloader = DataLoader(tokenized_datasets['train'], shuffle=True, batch_size=16)\n",
    "eval_dataloader = DataLoader(tokenized_datasets['test'], batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "# 加载模型并指定期望的标签数\n",
    "model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_path, num_labels=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "# AdamW over all model parameters, starting from a learning rate of 5e-5.\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_epochs = 2  #TODO\n",
    "num_training_steps = num_epochs * len(train_dataloader)\n",
    "lr_scheduler = get_scheduler(\n",
    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "model.train()\n",
    "\n",
    "# 初始化一个列表来保存损失值\n",
    "losses = []\n",
    "\n",
    "# 存储学习率的列表\n",
    "learning_rates = []\n",
    "\n",
    "# 用于比较 得到最好的模型参数\n",
    "best_loss = 1e10\n",
    "\n",
    "# 设置最小误差的模型保存路经\n",
    "save_model_path = \"./my_awesome_model\"\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in train_dataloader:\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        losses.append(loss)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        current_lr = lr_scheduler.get_last_lr()[0]  # 获取当前学习率\n",
    "        learning_rates.append(current_lr)\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "    if (epoch+1) % 1 == 0:\n",
    "        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')  \n",
    "    if loss.item() < best_loss:\n",
    "        best_loss = loss.item()\n",
    "        tokenizer.save_pretrained(save_model_path)\n",
    "        model.save_pretrained(save_model_path, safe_serialization=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets_path = \"/root/datasets/imdb\" #TODO\n",
    "\n",
    "datasets = load_dataset(datasets_path)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"./my_awesome_model\")\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True,max_length= 512) # max_length 必须指定值\n",
    "\n",
    "tokenized_datasets = datasets.map(tokenize_function, batched=True)\n",
    "\n",
    "\n",
    "# 移除 text 列, 因为模型不接受原始文本作为输入:\n",
    "tokenized_datasets = tokenized_datasets.remove_columns([\"text\"])\n",
    "# 将 label 列重命名为 labels, 因为模型期望参数的名称为 labels:\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "# 设置数据集的格式以返回 PyTorch 张量\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "\n",
    "\n",
    "eval_dataloader = DataLoader(tokenized_datasets['test'], batch_size=16)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DistilBertForSequenceClassification(\n",
       "  (distilbert): DistilBertModel(\n",
       "    (embeddings): Embeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (transformer): Transformer(\n",
       "      (layer): ModuleList(\n",
       "        (0-5): 6 x TransformerBlock(\n",
       "          (attention): MultiHeadSelfAttention(\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "          )\n",
       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (ffn): FFN(\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (activation): GELUActivation()\n",
       "          )\n",
       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
       "  (dropout): Dropout(p=0.2, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"./my_awesome_model\")\n",
    "\n",
    "model.to(device)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_metrics = evaluate.combine([\"/root/metrics/accuracy\", \"/root/metrics/f1\", \"/root/metrics/precision\", \"/root/metrics/recall\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e424d1b0623c45f4b65f5b7d6c9f9e1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1563 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.92976,\n",
       " 'f1': 0.9294439087110254,\n",
       " 'precision': 0.9336454633516306,\n",
       " 'recall': 0.92528}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "progress_bar = tqdm(range(1563))\n",
    "\n",
    "for i, batch in enumerate(eval_dataloader):\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    outputs = model(**batch)\n",
    "    logits = outputs.logits\n",
    "    predicted_class_id = logits.argmax(-1)\n",
    "    prediction = predicted_class_id.tolist()\n",
    "    reference = batch['labels'].tolist()\n",
    "    clf_metrics.add_batch(predictions=prediction, references=reference)\n",
    "    progress_bar.update(1)\n",
    "clf_metrics.compute()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
